#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging

# Public API of this module.
__all__ = ["Position"]

# Module-level logger used by Position.on_fill for fill and mismatch messages.
_logger = logging.getLogger(name=__name__)


class Position(object):
    """
    Position in a single security: signed size, average price and P&L.

    The object is kept even when the position is flat (size == 0), so the
    realized P&L accumulated by earlier fills is preserved.
    """

    def __init__(
        self,
        full_symbol: str,
        average_price: float,
        size: int,
    ) -> None:
        """
        Position includes zero/closed security

        :param full_symbol: symbol of the security; fills whose symbol differs
            are reported as errors by on_fill
        :param average_price: initial average price (commission-inclusive)
        :param size: signed quantity; > 0 long, < 0 short, 0 flat
        """
        ## TODO: add cumulative_commission, long_trades, short_trades, round_trip etc
        self.full_symbol: str = full_symbol
        # average price includes commission
        self.average_price: float = average_price
        self.size: int = size
        self.realized_pnl: float = 0.0
        self.unrealized_pnl: float = 0.0
        self.account: str = ""

    def get_current_pnl(self) -> tuple[float, float]:
        """Return ``(realized_pnl, unrealized_pnl)``."""
        return self.realized_pnl, self.unrealized_pnl

    def mark_to_market(self, last_price: float, multiplier: float) -> None:
        """
        Recompute unrealized_pnl from a new market price.

        :param last_price: latest market price of the security
        :param multiplier: contract multiplier applied to price * size
        """
        # if long or size > 0, pnl is positive if last_price > average_price
        # else if short or size < 0, pnl is positive if last_price < average_price
        self.unrealized_pnl = (last_price - self.average_price) * self.size * multiplier

    def on_fill(self, fill_event, multiplier: float, parent_name: str) -> None:  # type: ignore
        """
        Adjust average_price, size and realized_pnl for a new fill/trade/transaction.

        Cases, by the sign of the current size and of ``fill_event.fill_size``:

        * flat position: open at the fill price with the commission folded
          into the average price (a zero-size fill only charges commission);
        * fill in the same direction: blend the fill (plus commission) into
          the average price;
        * fill in the opposite direction, no larger than the position:
          realize P&L on the filled quantity, minus commission;
        * fill in the opposite direction, larger than the position: realize
          P&L on the whole position, minus commission, and open the remainder
          at the fill price.

        :param fill_event: object exposing ``full_symbol``, ``fill_price``,
            ``fill_size`` (signed) and ``commission``
        :param multiplier: contract multiplier applied to price * size
        :param parent_name: label prefixed to log messages
        """
        if self.full_symbol != fill_event.full_symbol:
            # The mismatch is only logged; the fill is still applied below.
            _logger.error(
                "%s Position symbol %s and fill event symbol %s do not match.",
                parent_name,
                self.full_symbol,
                fill_event.full_symbol,
            )

        fill_price = fill_event.fill_price
        fill_size = fill_event.fill_size
        commission = fill_event.commission

        if self.size == 0:  # no position: open a new one
            if fill_size == 0:
                # Nothing traded, so there is no price to average (and dividing
                # by fill_size would fail). Only charge the commission, matching
                # how a zero-size fill on an open position is handled below.
                self.realized_pnl -= commission
            else:
                # Spread the commission (in price units) over the opened
                # quantity; the sign of fill_size raises a long's average price
                # and lowers a short's.
                self.average_price = fill_price + commission / multiplier / fill_size
        elif (self.size > 0 and fill_size > 0) or (self.size < 0 and fill_size < 0):
            # Adding to the position: size-weighted average of the old basis
            # and the fill, with the commission included in the basis.
            self.average_price = (
                self.average_price * self.size
                + fill_price * fill_size
                + commission / multiplier
            ) / (self.size + fill_size)
        elif abs(self.size) >= abs(fill_size):
            # Reducing or exactly closing. fill_size has the opposite sign to
            # size (or is zero), so this is (fill - avg) * closed_qty for a long
            # and (avg - fill) * closed_qty for a short. The average price is
            # unchanged.
            self.realized_pnl += (
                self.average_price - fill_price
            ) * fill_size * multiplier - commission
        else:
            # Flipping: close the entire existing position at the fill price...
            self.realized_pnl += (
                fill_price - self.average_price
            ) * self.size * multiplier - commission
            # ...and open the remainder at the fill price. The whole commission
            # was charged to the closing leg above.
            self.average_price = fill_price

        self.size += fill_size

        _logger.info(
            "%s Position Fill: sym %s, avg price %s, fill price %s, fill size %s, "
            "after size %s, close pnl %s",
            parent_name,
            self.full_symbol,
            self.average_price,
            fill_price,
            fill_size,
            self.size,
            self.realized_pnl,
        )
